/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include "fmha.h"

// 128-bit vectorized global-memory load. The pointer must be 16-byte
// aligned; every byte offset computed in this file is a multiple of
// BYTES_PER_LDG (16), so this holds whenever the base pointers are
// 16-byte aligned.
inline __device__ float4 ldg128(const void *ptr) {
    return *static_cast<const float4 *>(ptr);
}

// 128-bit vectorized global-memory store; same alignment requirement as
// ldg128.
inline __device__ void stg128(void *ptr, const float4 &data) {
    *static_cast<float4 *>(ptr) = data;
}

// Sums CHUNKS partial dK/dV rows per token and writes the result into the
// dKV slice of the interleaved dQKV tensor. One CTA handles one token
// (one row); rows at or beyond the true token count cu_seqlens[batch_size]
// are padding and are zero-filled instead.
template<typename T, int THREADS, int HIDDEN_SIZE, int CHUNKS>
__global__ __launch_bounds__(THREADS) void fmha_noloop_reduce_kernel(void *__restrict__ out,
                                                                     const void *__restrict__ in,
                                                                     const int *__restrict__ cu_seqlens,
                                                                     const int batch_size) {

    // Each thread moves 16 bytes (one float4) per access, i.e. NUM_ELTS
    // elements of T.
    enum { BYTES_PER_LDG = 16 };
    enum { NUM_ELTS = BYTES_PER_LDG / sizeof(T) };

    // One input row per token: the dK and dV hidden vectors, back to back
    enum { BYTES_PER_ROW = HIDDEN_SIZE * sizeof(T) * 2 };
    // Row stride in bytes of the interleaved dQKV output (Q, K, V per token)
    enum { OUT_STRIDE_BYTES = 3 * HIDDEN_SIZE * sizeof(T) };
    // Offset in bytes within a dQKV row to the dKV part (past dQ), for
    // non-interleaved heads
    enum { OUT_OFFSET_KV_BYTES = HIDDEN_SIZE * sizeof(T) };

    // A row must split into whole 16-byte accesses.
    static_assert(BYTES_PER_ROW % BYTES_PER_LDG == 0);

    // Size in bytes of one token's input tile: CHUNKS partial dKV rows
    enum { BYTES_PER_TILE = CHUNKS * BYTES_PER_ROW };

    // Bytes moved by the whole CTA per load/store iteration
    enum { BYTES_PER_CTA = THREADS * BYTES_PER_LDG };

    // Number of 128-bit loads each thread issues to cover one row
    enum { LDGS = BYTES_PER_ROW / BYTES_PER_CTA };
    static_assert(BYTES_PER_CTA * LDGS == BYTES_PER_ROW);
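
    // For the instantiation used below (T = half, HIDDEN_SIZE = 1024,
    // THREADS = 256) these come out to: NUM_ELTS = 8, BYTES_PER_ROW = 4096,
    // OUT_STRIDE_BYTES = 6144, BYTES_PER_CTA = 4096 and LDGS = 1, so each
    // thread handles exactly one 16-byte vector per chunk.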

    // Lets the same 16 bytes be viewed either as the raw float4 used for
    // the 128-bit accesses or as NUM_ELTS individual elements of type T.
    union Vec_t {
        float4 raw;
        T elt[NUM_ELTS];
    };

    // ZERO-OUT: rows at or beyond the true token count are padding in
    // dQKV; clear the full 3 * HIDDEN_SIZE row and exit early.
    const int total = cu_seqlens[batch_size];
    if(blockIdx.x >= total){
        // Full dQKV row size and the number of 16-byte stores needed to
        // clear it, distributed over the CTA
        enum { BYTES_PER_QKV_ROW = 3 * HIDDEN_SIZE * sizeof(T) };
        enum { STGS = BYTES_PER_QKV_ROW / BYTES_PER_LDG };

        const float4 zeros = make_float4(0.f, 0.f, 0.f, 0.f);

        char *base_ptr = static_cast<char *>(out) + blockIdx.x * OUT_STRIDE_BYTES;

        for(int tidx = threadIdx.x; tidx < STGS; tidx += THREADS){
            stg128(base_ptr + tidx * BYTES_PER_LDG, zeros);
        }

        return;
    }

    // SETUP: ptr_in points at this token's tile of CHUNKS partial rows;
    // ptr_out points at the dKV slice of this token's dQKV row, skipping
    // the dQ part.
    const int offset_in = blockIdx.x * BYTES_PER_TILE + threadIdx.x * BYTES_PER_LDG;
    const char *ptr_in = static_cast<const char *>(in) + offset_in;

    const int offset_out = blockIdx.x * OUT_STRIDE_BYTES + threadIdx.x * BYTES_PER_LDG;
    char *ptr_out = static_cast<char *>(out) + OUT_OFFSET_KV_BYTES + offset_out;

    // LOAD: stage every chunk's partial row for this token in registers.
    Vec_t local_in[CHUNKS][LDGS];

    #pragma unroll
    for( int c = 0; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            int offset = c * BYTES_PER_ROW + l * BYTES_PER_CTA;
            local_in[c][l].raw = ldg128(ptr_in + offset);
        }
    }

    // UNPACK: seed the accumulator with chunk 0, converting to float so
    // the reduction runs in full precision.
    float acc[LDGS][NUM_ELTS];

    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            acc[l][e] = float(local_in[0][l].elt[e]);
        }
    }

    // COMPUTE: accumulate the remaining chunks in float.
    #pragma unroll
    for( int c = 1; c < CHUNKS; c++ ) {
        #pragma unroll
        for( int l = 0; l < LDGS; l++ ) {
            #pragma unroll
            for( int e = 0; e < NUM_ELTS; e++ ) {
                acc[l][e] += float(local_in[c][l].elt[e]);
            }
        }
    }

    // PACK: convert the reduced values back to T for the store.
    Vec_t local_out[LDGS];

    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        #pragma unroll
        for( int e = 0; e < NUM_ELTS; e++ ) {
            local_out[l].elt[e] = T(acc[l][e]);
        }
    }

    // STORE: write the reduced dKV row back into dQKV.
    #pragma unroll
    for( int l = 0; l < LDGS; l++ ) {
        const int offset = l * BYTES_PER_CTA;
        stg128(ptr_out + offset, local_out[l].raw);
    }
}

// Host-side launcher: one CTA per dQKV row. `total` is the number of rows
// in dQKV (it may exceed cu_seqlens[batch_size]; the excess padding rows
// are zeroed). Only fp16 with hidden_size == 1024 and num_chunks in {2, 3}
// is instantiated below.
void fmha_run_noloop_reduce(void *out,
                            const void *in,
                            const int *cu_seqlens,
                            const int hidden_size,
                            const int batch_size,
                            const int total,
                            const int num_chunks,
                            cudaStream_t stream) {

    const int blocks = total;

    if( hidden_size == 1024 ) {

        constexpr int HIDDEN_SIZE = 1024;
        constexpr int THREADS = 256;

        if( num_chunks == 2 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 2><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else if( num_chunks == 3 ) {
            fmha_noloop_reduce_kernel<half, THREADS, HIDDEN_SIZE, 3><<<blocks, THREADS, 0, stream>>>(out, in, cu_seqlens, batch_size);
        } else {
            assert(false && "Unsupported num_chunks");
        }

    } else {
        assert(false && "Unsupported hidden_size");
    }

    FMHA_CHECK_CUDA(cudaPeekAtLastError());
}
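
// Minimal usage sketch (buffer names are illustrative, not part of this
// file). From the pointer arithmetic in the kernel, `out` is the packed
// dQKV gradient tensor laid out as [total x 3 x hidden_size] half
// elements, `in` holds the partial dK/dV rows as
// [total x num_chunks x 2 x hidden_size], and `cu_seqlens` is a device
// array of batch_size + 1 prefix-summed sequence lengths; rows at or
// beyond cu_seqlens[batch_size] are zero-filled.
//
//     fmha_run_noloop_reduce(dqkv, dkv_partial, cu_seqlens,
//                            /*hidden_size=*/1024, batch_size,
//                            /*total=*/rows_in_dqkv,
//                            /*num_chunks=*/2, stream);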
